#define ARMA_DONT_PRINT_ERRORS
#include <RcppArmadillo.h>

// Declare functions defined in other files
// First, functions defined (& documented) in utils.cpp
arma::mat lag(const arma::mat& x, const int l = 1);
arma::mat embed(const arma::mat& x, const int p = 1);
arma::mat lambda_tilde(const arma::mat& x, const arma::mat& beta,
                       const arma::mat& D);
// Now functions defined (& documented) in moments.cpp
arma::mat bapvar_mean(const arma::mat& ylags, const arma::mat& A,
                      const arma::mat& lambda_tilde_mat);
arma::mat bapvar_var(const arma::mat& ylags, const arma::mat& A,
                     const arma::mat& lambda_tilde_mat, const arma::mat& D);

// [[Rcpp::export(.dynamics)]]
Rcpp::List dynamics(const arma::mat& samples, const arma::mat& y,
                    const arma::mat& x, const int p, const int h,
                    const bool verbose, const int check_nth,
                    const bool scale_by_mean) {
    int m = y.n_cols;   // # of equations
    int n = y.n_rows;   // # of observations
    int K = x.n_cols;   // # of exogenous variables (including intercept)
    int m2 = pow(m, 2); // m squared
    int n_A = m2 * p;   // # of elements in lag coef matrix
    int n_C = m2;       // # of elements in contemporary correlation matrix
    int n_D = m2;       // # of elements in covariance matrix for b
    int n_samples = samples.n_rows;   // # of posterior samples
    int NN = samples.n_cols;          // total # of "parameters"
    arma::mat ylags = embed(y, p);    // matrix of lag observations
    arma::mat this_beta(K, m);        // holds one beta draw
    arma::mat this_A(m * p, m);       // holds one A draw
    arma::mat this_D(m, m);           // holds one D draw
    arma::mat lambda_tilde_mat(n, m); // see utils.cpp
    arma::mat this_mean(n, m);        // holds draw of conditional mean
    arma::mat this_variance(m, m);    // holds draw of conditional variance
    arma::mat A0(m, m);               // initial shock
    arma::mat tmp_step(m, m);         // additional for current horizon step
    arma::mat cum_tmp_step_sq(m, m);  // cumulative at each horizon step
    arma::mat lambda_means(m, m);
    arma::cube irf_samples(n_samples, m*m, h);  // result object (IRF samples)
    arma::cube fevd_samples(n_samples, m*m, h); // result object (FEVD samples)
    // Now for every posterior sample,
    for ( int t = 0; t < n_samples; ++t ) {
        // (periodically check for user interrupt & update progress if verbose)
        if ( (t+1) % check_nth == 0 ) {
            Rcpp::checkUserInterrupt();
            if ( verbose ) {
                Rcpp::Rcout << "Computing dynamics for sample " << t+1 << "\n";
            }
        }
        // store the parameter draws in the proper objects
        for ( int j = 0; j < m; ++j ) { // lag coefficient matrix
            for ( int k = 0; k < (m * p); ++k ) {
                this_A(k, j) = samples(t, k + j*m);
            }
            for ( int k = 0; k < m; ++k ) { // b covariance matrix
                this_D(k, j) = samples(t, k + j*m + n_A + n_C);
            }
            for ( int k = 0; k < K; ++k ) { // exogenous predictor coef matrix
                this_beta(k, j) = samples(t, k + j*K + n_A + n_C + n_D);
            }
        }
        // Now we can get the mean, variance, and initial shock
        lambda_tilde_mat = lambda_tilde(x, this_beta, this_D);
        this_variance = bapvar_var(ylags, this_A, lambda_tilde_mat, this_D);
        try { // we have to make sure we don't have numerical stability issues
            A0 = arma::chol(this_variance);
        }
        catch (...) {
            arma::rowvec r(m*m);
            r.fill(R_NaReal);
            for ( int step = 0; step < h; ++step ) {
                for ( int j = 0; j < irf_samples.n_cols; ++j ) {
                    irf_samples(t, j, step) = r(j);
                    fevd_samples(t, j, step) = r(j);
                }
            }
            continue;
        }
        if ( scale_by_mean ) {
            this_mean = bapvar_mean(ylags, this_A, lambda_tilde_mat);
            for ( int j = 0; j < m; ++j ) {
                double this_colmean = arma::mean(this_mean.col(j));
                for ( int i = 0; i < m; ++i ) {
                    lambda_means(j, i) = this_colmean;
                }
            }
            A0 = A0 % lambda_means;
        }
        tmp_step = A0;
        cum_tmp_step_sq = arma::square(tmp_step);
        for ( int step = 0; step < h; ++step ) {
            for ( int j = 0; j < m; ++j ) {
                double total_fev = arma::sum(cum_tmp_step_sq.col(j));
                for ( int i = 0; i < m; ++i ) {
                    irf_samples(t, i + j*m, step) = tmp_step(i, j);
                    double this_fevd = cum_tmp_step_sq(i, j) / total_fev;
                    fevd_samples(t, i + j*m, step) = this_fevd;
                }
            }
            tmp_step = tmp_step * this_A;
            cum_tmp_step_sq = cum_tmp_step_sq + arma::square(tmp_step);
        }
    }
    return Rcpp::List::create(Rcpp::_["irf"] = irf_samples,
                              Rcpp::_["fevd"] = fevd_samples);
}

// [[Rcpp::export(.unit_dynamics)]]
Rcpp::List unit_dynamics(const arma::mat& samples, const arma::mat& y,
                         const arma::mat& x, const int p, const int h,
                         const bool verbose, const int check_nth,
                         const bool scale_by_mean) {
    int m = y.n_cols;
    int n = y.n_rows;
    int K = x.n_cols;
    int m2 = pow(m, 2);
    int n_A = m2 * p;
    int n_C = m2;
    int n_D = m2;
    int n_samples = samples.n_rows;
    int NN = samples.n_cols;
    arma::mat ylags = embed(y, p);
    arma::mat this_beta(K, m);
    arma::mat this_A(m * p, m);
    arma::mat this_D(m, m);
    arma::mat lambda_tilde_mat(n, m);
    arma::mat this_mean(n, m);
    arma::mat this_variance(m, m);
    arma::mat A0(m, m);
    arma::mat tmp_step(m, m);
    arma::mat cum_tmp_step_sq(m, m);
    arma::mat lambda_means(m, m);
    arma::cube irf_samples(n_samples, m*m, h);
    arma::cube fevd_samples(n_samples, m*m, h);
    for ( int t = 0; t < n_samples; ++t ) {
        if ( (t+1) % check_nth == 0 ) {
            Rcpp::checkUserInterrupt();
            if ( verbose ) {
                Rcpp::Rcout << "Computing dynamics for sample " << t+1 << "\n";
            }
        }
        for ( int j = 0; j < m; ++j ) {
            for ( int k = 0; k < (m * p); ++k ) {
                this_A(k, j) = samples(t, k + j*m);
            }
            for ( int k = 0; k < m; ++k ) {
                this_D(k, j) = samples(t, k + j*m + n_A + n_C);
            }
            for ( int k = 0; k < K; ++k ) {
                this_beta(k, j) = samples(t, k + j*K + n_A + n_C + n_D);
            }
        }
        lambda_tilde_mat = lambda_tilde(x, this_beta, this_D);
        this_variance = bapvar_var(ylags, this_A, lambda_tilde_mat, this_D);
        A0 = arma::eye(m, m);
        if ( scale_by_mean ) {
            this_mean = bapvar_mean(ylags, this_A, lambda_tilde_mat);
            for ( int j = 0; j < m; ++j ) {
                double this_colmean = arma::mean(this_mean.col(j));
                for ( int i = 0; i < m; ++i ) {
                    lambda_means(j, i) = this_colmean;
                }
            }
            A0 = A0 % lambda_means;
        }
        tmp_step = A0;
        cum_tmp_step_sq = arma::square(tmp_step);
        for ( int step = 0; step < h; ++step ) {
            for ( int j = 0; j < m; ++j ) {
                double total_fev = arma::sum(cum_tmp_step_sq.col(j));
                for ( int i = 0; i < m; ++i ) {
                    irf_samples(t, i + j*m, step) = tmp_step(i, j);
                    double this_fevd = cum_tmp_step_sq(i, j) / total_fev;
                    fevd_samples(t, i + j*m, step) = this_fevd;
                }
            }
            tmp_step = tmp_step * this_A;
            cum_tmp_step_sq = cum_tmp_step_sq + arma::square(tmp_step);
        }
    }
    return Rcpp::List::create(Rcpp::_["irf"] = irf_samples,
                              Rcpp::_["fevd"] = fevd_samples);
}

